How to Implement Your Own Estimators
All estimators in all packages follow pre-defined protocols determined by their type. Any algorithm implementation that follows the protocols in s3l.base can be evaluated by the experiment classes in the same way as the built-in algorithms.
Your estimator should inherit from the base estimator class in s3l.base that matches the type of estimator you are implementing. Five base classes are currently provided: TransductiveEstimatorwithGraph, TransductiveEstimatorWOGraph, InductiveEstimatorWOGraph, InductiveEstimatorwithGraph and SupervisedEstimator.
As the names indicate, the experiments support supervised learning algorithms. They also support semi-supervised learning algorithms in both the inductive and transductive settings, with or without a graph.
For each estimator class, you must implement the following methods: set_params, fit and predict.
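The contract can be pictured as an abstract base class. The sketch below is a hypothetical stand-in written with Python's abc module. It is not the actual code of s3l.base. It only shows that a subclass which forgets one of the three methods cannot be instantiated.

```python
from abc import ABC, abstractmethod


class EstimatorProtocol(ABC):
    """Hypothetical stand-in for a base class in s3l.base (not the real one)."""

    @abstractmethod
    def set_params(self, param):
        """Configure hyper-parameters from a dict and reset any trained state."""

    @abstractmethod
    def fit(self, X, y, l_ind, **kwargs):
        """Train on X and y, using only the labels at the indexes in l_ind."""

    @abstractmethod
    def predict(self, *args, **kwargs):
        """Take indexes (transductive) or features (inductive) to predict."""


class Incomplete(EstimatorProtocol):
    # predict() is missing, so this class is still abstract.
    def set_params(self, param):
        pass

    def fit(self, X, y, l_ind, **kwargs):
        return self


class Complete(Incomplete):
    # Implementing predict() makes every abstract method concrete.
    def predict(self, idx, **kwargs):
        return [None for _ in idx]


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abstract methods remain."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete), can_instantiate(Complete))  # False True
```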
set_params is the method that configures the estimator's parameters from a dict that maps parameter names to values. The experiments call it while searching for the best hyper-parameters. The same object is reused with different hyper-parameters, so set_params must reset the object as if it had never been trained. A common implementation is as follows:
def set_params(self, param):
    """Parameter setting function.

    Parameters
    ----------
    param: dict
        Stores parameter names and corresponding values {'name': value}.
    """
    if isinstance(param, dict):
        self.__dict__.update(param)
    # Code to reset any properties that may influence the
    # prediction goes here.
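The reset matters when one object is reused. The toy estimator below is hypothetical and not part of s3l. It follows the set_params pattern above and clears its learned state in a hypothetical _reset helper, so a hyper-parameter loop can safely reuse one object. Its predict takes instance indexes, as a transductive estimator's would.

```python
class ShrunkMeanEstimator:
    """Toy regressor: predicts the labeled mean shrunk towards zero by alpha."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha
        self._reset()

    def _reset(self):
        # Everything learned by fit() lives here and must be cleared
        # whenever the hyper-parameters change.
        self.mean_ = None

    def set_params(self, param):
        if isinstance(param, dict):
            self.__dict__.update(param)
        self._reset()

    def fit(self, X, y, l_ind, **kwargs):
        # X is unused by this toy model; only labeled targets matter.
        labeled = [y[i] for i in l_ind]
        self.mean_ = sum(labeled) / (len(labeled) + self.alpha)
        return self

    def predict(self, idx, **kwargs):
        if self.mean_ is None:
            raise RuntimeError("estimator is not fitted")
        return [self.mean_ for _ in idx]


X = [[0.0], [1.0], [2.0], [3.0]]
y = [2.0, 4.0, 0.0, 0.0]  # only y[0] and y[1] are treated as labeled

est = ShrunkMeanEstimator()
results = {}
for alpha in (0.0, 2.0):
    est.set_params({"alpha": alpha})  # also wipes the previous fit
    est.fit(X, y, l_ind=[0, 1])
    results[alpha] = est.predict([2, 3])
print(results)  # {0.0: [3.0, 3.0], 2.0: [1.5, 1.5]}
```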
fit is the method that trains the model on the given data, and predict is the method that makes predictions. The base classes differ mainly in the parameters of fit and predict. For a transductive estimator, the predict method takes the indexes of the instances to predict, because the estimator can see the testing data during training. For an inductive estimator, the predict method takes the features of the instances to predict. The fit method always takes X, y and l_ind (the indexes of the labeled instances), and optional arguments are supported. For graph-based algorithms, the graph W must be provided to the fit method.
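Two toy estimators show the two calling conventions side by side. Both are hypothetical and written in plain Python so they run without s3l. The transductive one receives a weight matrix W in fit and is later asked for predictions by instance index. The inductive one is asked for predictions on new feature vectors. The graph rule is a deliberately simple stand-in, not label propagation: each unlabeled instance copies the label of its most strongly connected labeled neighbour. Using -1 for unknown labels is also an assumption of this sketch.

```python
class NearestLabeledNeighbourGraph:
    """Transductive, with graph: predict() receives indexes into the training X."""

    def set_params(self, param):
        if isinstance(param, dict):
            self.__dict__.update(param)
        self.labels_ = None  # reset the learned labels

    def fit(self, X, y, l_ind, W=None, **kwargs):
        if W is None:
            raise ValueError("graph-based estimators need W")
        labeled = set(l_ind)
        self.labels_ = []
        for i in range(len(X)):
            if i in labeled:
                self.labels_.append(y[i])
            else:
                # Strongest edge from instance i to any labeled instance.
                j = max(l_ind, key=lambda k: W[i][k])
                self.labels_.append(y[j])
        return self

    def predict(self, idx, **kwargs):
        return [self.labels_[i] for i in idx]


class OneNearestNeighbour:
    """Inductive, without graph: predict() receives feature vectors."""

    def set_params(self, param):
        if isinstance(param, dict):
            self.__dict__.update(param)
        self.train_ = None  # reset the stored training set

    def fit(self, X, y, l_ind, **kwargs):
        # Only the labeled instances are kept for prediction.
        self.train_ = [(X[i], y[i]) for i in l_ind]
        return self

    def predict(self, X, **kwargs):
        def sq_dist(a, b):
            return sum((u - v) ** 2 for u, v in zip(a, b))
        return [min(self.train_, key=lambda t: sq_dist(t[0], x))[1] for x in X]


X = [[0.0], [0.1], [1.0], [1.1]]
y = [0, -1, 1, -1]  # -1 marks an unknown label (assumption of this sketch)
W = [[0.0, 0.9, 0.1, 0.0],
     [0.9, 0.0, 0.2, 0.1],
     [0.1, 0.2, 0.0, 0.8],
     [0.0, 0.1, 0.8, 0.0]]

graph_pred = NearestLabeledNeighbourGraph().fit(X, y, [0, 2], W=W).predict([1, 3])
ind_pred = OneNearestNeighbour().fit(X, y, [0, 2]).predict([[0.2], [0.9]])
print(graph_pred, ind_pred)  # [0, 1] [0, 1]
```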
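For a supervised learning algorithm, you can inherit from the SupervisedEstimator class. You must override the __init__ method and initialize the member model as a supervised learning model object. As the class definition below shows, SupervisedEstimator delegates to model, so model must provide the methods fit, predict, set_params, predict_proba and predict_log_proba: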
from abc import abstractmethod

class SupervisedEstimator(BaseEstimator):
    """Supervised estimator for single-label tasks."""

    @abstractmethod
    def __init__(self):
        super(SupervisedEstimator, self).__init__()
        self.model = None  # subclasses must set this to a model object

    def fit(self, X, y, l_ind=None, **kwargs):
        """Takes X, y and l_ind (the indexes of the labeled instances)."""
        if l_ind is not None:
            # Train only on the labeled instances.
            X = X[l_ind, :]
            if y.ndim == 2:
                # Flatten column-vector labels to 1-D.
                y = y[l_ind, :].reshape(-1)
            else:
                y = y[l_ind]
        self.model.fit(X, y)

    def predict(self, X, **kwargs):
        """Takes X."""
        return self.model.predict(X)

    def set_params(self, param):
        # Forward the hyper-parameter dict to the wrapped model.
        self.model.set_params(**param)

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    def predict_log_proba(self, X):
        return self.model.predict_log_proba(X)
The module s3l.wrapper.sklearn_wrapper shows how to wrap any supervised learning algorithm you like.
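SupervisedEstimator only forwards calls to self.model. Any object exposing fit, predict, set_params(**kwargs), predict_proba and predict_log_proba can therefore be plugged in, and scikit-learn classifiers such as LogisticRegression already provide these methods. The plain-Python PriorModel below is a hypothetical example written to make the contract explicit. The satisfies_model_contract check is also hypothetical. In your own subclass you would assign the model in __init__, for example self.model = PriorModel().

```python
import math
from collections import Counter

# Methods that SupervisedEstimator calls on self.model.
REQUIRED_MODEL_METHODS = ("fit", "predict", "set_params",
                          "predict_proba", "predict_log_proba")


def satisfies_model_contract(model):
    """Hypothetical helper: True if every required method is callable."""
    return all(callable(getattr(model, name, None))
               for name in REQUIRED_MODEL_METHODS)


class PriorModel:
    """Hypothetical toy classifier: predicts class frequencies, ignoring X."""

    def __init__(self, smoothing=1.0):
        self.smoothing = smoothing
        self.classes_ = []
        self.counts_ = Counter()

    def set_params(self, **params):
        # Keyword arguments, matching self.model.set_params(**param).
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        self.counts_ = Counter(y)
        self.classes_ = sorted(self.counts_)
        return self

    def predict_proba(self, X):
        # Additive (Laplace) smoothing over the observed classes.
        total = sum(self.counts_.values()) + self.smoothing * len(self.classes_)
        row = [(self.counts_[c] + self.smoothing) / total for c in self.classes_]
        return [list(row) for _ in X]

    def predict_log_proba(self, X):
        return [[math.log(p) for p in row] for row in self.predict_proba(X)]

    def predict(self, X):
        # Always predict the most frequent training class.
        best = max(self.classes_, key=lambda c: self.counts_[c])
        return [best for _ in X]


model = PriorModel().set_params(smoothing=0.0).fit([[0], [1], [2]], ["a", "a", "b"])
print(satisfies_model_contract(model), model.predict([[5]]))  # True ['a']
```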
Attention
Your estimator class may sometimes contain a C-language object as a member. If that C object holds pointers, the estimator object may not be serializable, because the Python interpreter has no way to know the details of the memory the pointers point to.
The experiment classes run the experiments in multi-process mode when n_jobs is set larger than 1, which requires the estimator object to be serializable. One option is to override the __getstate__ and __setstate__ methods to control how pickle dumps and loads the estimator object. The simplest approach is to drop the unpicklable member in __getstate__ and re-initialize it in __setstate__. Here is an example taken from s3l.classification.TSVM, where self.model is a C object:
def __getstate__(self):
    """The model is a ctypes object that contains pointers, which cannot
    be pickled, so we drop the model when we pickle TSVM.
    """
    state = self.__dict__.copy()
    del state['model']  # manually delete
    return state

def __setstate__(self, state):
    """The model was dropped when pickling (see __getstate__), so it is
    reset to None when TSVM is unpickled.
    """
    self.__dict__.update(state)
    self.model = None  # manually update
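You can check the effect of these two methods with pickle directly. In the sketch below, a ctypes pointer stands in for the C model, because pickling any ctypes object that contains pointers fails. The class CBackedEstimator and its helper _build_model are hypothetical. _build_model re-creates the member after loading, whereas the TSVM code above simply resets it to None.

```python
import ctypes
import pickle


class CBackedEstimator:
    """Hypothetical estimator whose model member is an unpicklable ctypes pointer."""

    def __init__(self, C=1.0):
        self.C = C
        self.model = self._build_model()

    def _build_model(self):
        # Stand-in for a C object: a pointer to a freshly allocated int.
        return ctypes.pointer(ctypes.c_int(0))

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['model']  # drop the member that pickle cannot handle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.model = self._build_model()  # re-initialize after loading


est = CBackedEstimator(C=0.5)
clone = pickle.loads(pickle.dumps(est))  # what multi-process runs rely on
print(clone.C, clone.model.contents.value)  # 0.5 0
```

Re-building the member eagerly keeps the loaded object usable straight away. Resetting it to None, as TSVM does, is enough when fit creates a new model anyway.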